package cn.js189.gateway.filter;

import cn.js189.common.util.ServletUtils;
import cn.js189.common.util.StringUtils;
import cn.js189.gateway.config.ChannelAuthProperties;
import cn.js189.gateway.utils.InterCountUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.HttpMethod;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.function.server.HandlerStrategies;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import javax.annotation.Resource;
import java.nio.charset.StandardCharsets;

/**
 * Gateway authentication demo filter.
 *
 * <p>Paths matching {@link ChannelAuthProperties#getWhites()} are passed straight through. For
 * {@code POST} requests whose {@code Content-Type} header contains {@code application/json}, the
 * request body is read fully into memory, handed to
 * {@link InterCountUtil#saveInterCount(String, String)} together with the request path (interface
 * call statistics), and then replayed to the downstream chain from the in-memory copy so that
 * later filters and the proxied service can still read it. All other requests are forwarded
 * unchanged.
 *
 * @author hxd
 */
@Slf4j
@Component
public class ChannelAuthFilter implements GlobalFilter, GatewayFilter, Ordered {
	
	/** URI patterns excluded from this filter; maintained in the Nacos configuration. */
	@Resource
	private ChannelAuthProperties channelAuthProperties;
	
	/**
	 * Records JSON POST bodies for call statistics and forwards the request down the chain.
	 *
	 * @param exchange current server exchange
	 * @param chain    remaining gateway filter chain
	 * @return completion signal of the downstream chain
	 */
	@Override
	public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
		ServerHttpRequest request = exchange.getRequest();
		// Builder used for request mutations (e.g. addHeader); applied on both branches below.
		ServerHttpRequest.Builder mutate = request.mutate();
		String method = request.getMethodValue();
		String contentType = request.getHeaders().getFirst("Content-Type");
		String url = request.getURI().getPath();
		// Skip paths that do not require verification.
		if (StringUtils.matches(url, channelAuthProperties.getWhites())) {
			return chain.filter(exchange);
		}
		
		if (null != contentType && HttpMethod.POST.name().equalsIgnoreCase(method) && contentType.contains("application/json")) {
			return DataBufferUtils.join(request.getBody())
					.map(ChannelAuthFilter::readAndRelease)
					// join() completes EMPTY when there is no body; without a default the flatMap
					// below would never run and the request would never reach the chain.
					.defaultIfEmpty(new byte[0])
					.flatMap(bytes -> {
						String bodyString = new String(bytes, StandardCharsets.UTF_8);
						// Interface call-count statistics.
						// NOTE(review): if saveInterCount does blocking I/O it runs on the reactive
						// (event-loop) thread here — confirm it is non-blocking.
						InterCountUtil.saveInterCount(url, bodyString);
						ServerHttpRequest cachedRequest = cacheBody(exchange, mutate.build(), bytes);
						return chain.filter(exchange.mutate().request(cachedRequest).build());
					});
		}
		return chain.filter(exchange.mutate().request(mutate.build()).build());
	}
	
	/**
	 * Copies the readable bytes of a joined body buffer into a new array and releases the buffer.
	 *
	 * @param dataBuffer the joined request body; always released, even if reading fails
	 * @return the body bytes
	 */
	private static byte[] readAndRelease(DataBuffer dataBuffer) {
		try {
			byte[] bytes = new byte[dataBuffer.readableByteCount()];
			dataBuffer.read(bytes);
			return bytes;
		} finally {
			DataBufferUtils.release(dataBuffer);
		}
	}
	
	/**
	 * Wraps {@code request} so that its body is served from {@code bytes}.
	 *
	 * <p>Each subscription to {@code getBody()} receives a freshly wrapped buffer, so the body can be
	 * consumed more than once. The buffer is not retained: whoever consumes it releases it, and an
	 * extra retain would leave a reference that is never released.
	 *
	 * @param exchange exchange whose response buffer factory is used to wrap the bytes
	 * @param request  request to decorate
	 * @param bytes    cached body content
	 * @return a request whose body replays {@code bytes}
	 */
	private static ServerHttpRequest cacheBody(ServerWebExchange exchange, ServerHttpRequest request, byte[] bytes) {
		Flux<DataBuffer> cachedFlux = Flux.defer(() -> Mono.just(exchange.getResponse().bufferFactory().wrap(bytes)));
		return new ServerHttpRequestDecorator(request) {
			@Override
			public Flux<DataBuffer> getBody() {
				return cachedFlux;
			}
		};
	}
	
	/**
	 * Adds a URL-encoded request header to the request being built.
	 *
	 * @param mutate request builder to add the header to
	 * @param name   header name
	 * @param value  header value; {@code null} values are ignored, others are converted with
	 *               {@code toString()} and URL-encoded via {@link ServletUtils#urlEncode(String)}
	 */
	private void addHeader(ServerHttpRequest.Builder mutate, String name, Object value) {
		if (value == null) {
			return;
		}
		String valueStr = value.toString();
		String valueEncode = ServletUtils.urlEncode(valueStr);
		mutate.header(name, valueEncode);
	}
	
	/**
	 * Runs this filter before all other filters.
	 *
	 * @return {@link Ordered#HIGHEST_PRECEDENCE}
	 */
	@Override
	public int getOrder() {
		return Ordered.HIGHEST_PRECEDENCE;
	}
	
}